'''
Author: SlytherinGe
LastEditTime: 2021-08-31 21:30:23
'''
import xml.etree.ElementTree as ET
import os.path as osp
import os
if __name__ == '__main__':
    import voc_label_utility as util
else:
    import dataset_utilty.voc_label_utility as util

# VEDAI class name <-> class index, where the index is the string found in the
# class column of the raw VEDAI annotation text files. Both lookup tables are
# built from this single list of pairs so they always stay in sync.
_VEDAI_CLASS_INDEX_PAIRS = (
    ('car', '1'),
    ('truck', '2'),
    ('ship', '23'),
    ('tractor', '4'),
    ('camping car', '5'),
    ('bike', '7'),
    ('others', '8'),
    ('van', '9'),
    ('vehicle', '10'),
    ('pick-up', '11'),
    ('plane', '31'),
)
# class name -> index string
vedai_class_2_index_map = dict(_VEDAI_CLASS_INDEX_PAIRS)
# index string -> class name (exact inverse of the table above)
vedai_index_2_class_map = {index: name for name, index in _VEDAI_CLASS_INDEX_PAIRS}

def get_filename_and_bndbox_from_rotated_label(input_file):
    """Read a polygon-style XML label and turn every object into an axis-aligned box.

    The label is expected to contain ``source/filename`` and an ``objects``
    element whose children each hold ``possibleresult/name`` and a
    ``points`` element with ``<point>X, Y</point>`` vertices.

    Args:
        input_file: anything ``ET.parse`` accepts (a path or a file object).

    Returns:
        ``(fname, obj_list)``. ``fname`` is the text of ``source/filename``.
        ``obj_list`` holds one dict per object with at least one vertex:
        ``{'name': str, 'bndbox': {'xmin', 'ymin', 'xmax', 'ymax'}}``. The box
        values are the vertex coordinate strings exactly as written in the file
        (whitespace-stripped); only the min/max selection is numeric.
    """
    root = ET.parse(input_file).getroot()
    fname = root.find('source').find('filename').text
    objects = root.find('objects')
    obj_list = []
    if objects is None:
        # A label without an <objects> section simply has no boxes.
        return fname, obj_list

    for obj in objects:
        class_name = obj.find('possibleresult').find('name').text
        points = obj.find('points')
        if points is None:
            continue
        xs, ys = [], []
        for point in points:
            # Accept both "x, y" and "x,y"; keep the stripped text so the
            # output keeps the input's number formatting.
            x_str, y_str = (part.strip() for part in point.text.split(',')[:2])
            xs.append(x_str)
            ys.append(y_str)
        if not xs:
            # An object without vertices has no box; skip it.
            continue
        # Compare numerically: plain string comparison would rank '100.5'
        # below '99.5' and produce a wrong box.
        obj_list.append({
            'name': class_name,
            'bndbox': {
                'xmin': min(xs, key=float),
                'ymin': min(ys, key=float),
                'xmax': max(xs, key=float),
                'ymax': max(ys, key=float),
            },
        })
    return fname, obj_list

def get_xml_tree_bbox_anno_from_bbox_list(fname, obj_list, width=1000, height=1000, channel=3):
    """Build a VOC-style annotation tree from a list of box dicts.

    Args:
        fname: filename value forwarded to ``util.voc_basic_xml_element_generator``.
        obj_list: iterable of dicts holding at least the keys that
            ``util.voc_object_xml_element_generator`` needs (e.g. ``'name'`` and
            ``'bndbox'``, as produced by
            ``get_filename_and_bndbox_from_rotated_label``). The dicts are not
            modified.
        width, height, channel: image size passed to the header generator.
            The defaults reproduce the previously hard-coded 1000x1000x3.

    Returns:
        ``ET.ElementTree`` rooted at ``<annotation>``. It contains the elements
        returned by ``util.voc_basic_xml_element_generator``, followed by one
        object element per entry of ``obj_list``.
    """
    root = ET.Element('annotation')
    tree = ET.ElementTree(root)
    # Presumably the filename / folder / size header elements; confirm in util.
    for header_ele in util.voc_basic_xml_element_generator(fname, 'VOC2012', width, height, channel):
        root.append(header_ele)
    for dict_obj in obj_list:
        # Work on a copy so the caller's dicts are left untouched. Every object
        # is written as not truncated / not difficult, overriding any values
        # already present in the input dict.
        voc_obj = dict(dict_obj, truncated='0', difficult='0')
        root.append(util.voc_object_xml_element_generator(voc_obj))
    return tree


def get_filename_and_bndbox_from_VEDAI(input_file, width, height, channel):
    """Convert one VEDAI text annotation file into a VOC-style XML tree.

    Despite its name, this returns an XML tree, not a (filename, boxes) pair.

    Every non-blank line of ``input_file`` must hold at least 14
    whitespace-separated fields, laid out as::

        center_x center_y angle class_index truncated difficult ax bx cx dx ay by cy dy

    The four corner points (ax, ay) .. (dx, dy) are reduced to their
    axis-aligned bounding box. Lines with fewer than 14 fields are reported
    on stdout and skipped.

    Args:
        input_file: path to the VEDAI ``.txt`` annotation file. Its base name,
            without the final extension, becomes the VOC filename.
        width, height, channel: image size written into the VOC header.

    Returns:
        ``ET.ElementTree`` rooted at ``<annotation>``. It contains the elements
        from ``util.voc_basic_xml_element_generator``, followed by one object
        element per valid line.

    Raises:
        KeyError: if a line's class index is missing from ``vedai_index_2_class_map``.
        ValueError: if a corner coordinate is not an integer string.
    """
    # splitext removes only the final extension, so dotted base names survive.
    stem = osp.splitext(osp.basename(input_file))[0]
    root = ET.Element('annotation')
    tree = ET.ElementTree(root)
    for header_ele in util.voc_basic_xml_element_generator(stem, 'VOC2012', width, height, channel):
        root.append(header_ele)

    with open(input_file, 'r') as f:
        for line in f:
            # split() with no argument tolerates repeated spaces, tabs and '\r'.
            anno_data = line.split()
            if not anno_data:
                continue  # blank line, e.g. a trailing newline at EOF
            if len(anno_data) < 14:
                print('error anno file:', input_file)
                continue
            # NOTE(review): fields 4/5 are treated as truncated/difficult here;
            # confirm against the VEDAI spec, where they may be
            # "contained"/"occluded" flags with a different polarity.
            xs = [int(v) for v in anno_data[6:10]]
            ys = [int(v) for v in anno_data[10:14]]
            obj_dict = {
                'name': vedai_index_2_class_map[anno_data[3]],
                'truncated': anno_data[4],
                'difficult': anno_data[5],
                'bndbox': {
                    'xmin': min(xs),
                    'ymin': min(ys),
                    'xmax': max(xs),
                    'ymax': max(ys),
                },
            }
            root.append(util.voc_object_xml_element_generator(obj_dict))
    return tree

def get_bbox_from_SSDDpp(input_file):
    """Convert an SSDD++ rotated-box VOC annotation into an axis-aligned VOC tree.

    The header elements returned by ``util.voc_label_preprocess`` are copied
    verbatim. Each object is read through ``util.voc_object_xml_element_resolver``.
    The resolved dict must provide ``name``, ``pose``, ``truncated``,
    ``difficult`` and a ``bndbox`` with integer-string corners ``x1..x4`` /
    ``y1..y4``. The four corners are replaced by their enclosing axis-aligned
    box.

    Args:
        input_file: annotation file passed to ``util.voc_label_preprocess``.

    Returns:
        ``ET.ElementTree`` rooted at ``<annotation>``.

    Raises:
        KeyError: if a resolved object lacks one of the keys listed above.
        ValueError: if a corner coordinate is not an integer string.
    """
    root = ET.Element('annotation')
    tree = ET.ElementTree(root)
    anno_info, anno_obj = util.voc_label_preprocess(input_file)
    for info in anno_info:
        root.append(info)
    for obj in anno_obj:
        src = util.voc_object_xml_element_resolver(obj)
        corners = src['bndbox']
        xs = [int(corners['x%d' % i]) for i in range(1, 5)]
        ys = [int(corners['y%d' % i]) for i in range(1, 5)]
        voc_obj = {
            'name': src['name'],
            'pose': src['pose'],
            'truncated': src['truncated'],
            'difficult': src['difficult'],
            'bndbox': {
                'xmin': min(xs),
                'ymin': min(ys),
                'xmax': max(xs),
                'ymax': max(ys),
            },
        }
        root.append(util.voc_object_xml_element_generator(voc_obj))
    return tree



if __name__ == '__main__':
    # Batch-convert every VEDAI annotation file in RANNO_PATH into a VOC XML
    # file in SAVE_PATH, named '512_<stem>.xml'.
    RANNO_PATH = '/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/Origin_Anno_512'
    SAVE_PATH = '/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/VOC_Anno_512'
    # NOTE(review): annotations come from the *512* folder, but the size written
    # into every XML header is 1024x1024x3 — confirm this matches the images.
    IMG_WIDTH, IMG_HEIGHT, IMG_CHANNEL = 1024, 1024, 3

    # makedirs also creates missing parent folders and is a no-op if SAVE_PATH exists.
    os.makedirs(SAVE_PATH, exist_ok=True)

    for f in sorted(os.listdir(RANNO_PATH)):
        src_path = osp.join(RANNO_PATH, f)
        if not osp.isfile(src_path):
            continue  # sub-directories cannot be parsed as annotation files
        ele_tree = get_filename_and_bndbox_from_VEDAI(src_path, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNEL)
        out_name = '512_' + osp.splitext(f)[0] + '.xml'
        ele_tree.write(osp.join(SAVE_PATH, out_name), encoding='utf-8', xml_declaration=True)


# if __name__ == '__main__':
#     ele_tree = get_filename_and_bndbox_from_VEDAI('/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/Origin_Anno_1024/00000006.txt', 1024, 1024, 3)
#     ele_tree.write('/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/test.xml', encoding='utf-8', xml_declaration=True)
